/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.pinot.broker.routing.instanceselector;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.time.Clock;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.helix.model.ExternalView;
import org.apache.helix.model.IdealState;
import org.apache.helix.store.zk.ZkHelixPropertyStore;
import org.apache.helix.zookeeper.datamodel.ZNRecord;
import org.apache.pinot.broker.routing.adaptiveserverselector.AdaptiveServerSelector;
import org.apache.pinot.common.assignment.InstancePartitions;
import org.apache.pinot.common.assignment.InstancePartitionsUtils;
import org.apache.pinot.common.metrics.BrokerMetrics;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.config.table.assignment.InstancePartitionsType;
import org.apache.pinot.spi.utils.builder.TableNameBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Instance selector for multi-stage queries which can ensure that Colocated Tables always leverage Colocated Join
 * whenever possible. To achieve this, this instance-selector uses InstancePartitions (IP) to determine replica-groups,
 * as opposed to IdealState used by other instance-selectors. Moreover, this also uses the requestId generated by
 * Pinot broker to determine the replica-group picked for each table involved in the query, as opposed to using a
 * member variable. There may be scenarios where an instance in the chosen replica-group is down. In that case, this
 * strategy will try to pick another replica-group. For realtime tables, this strategy uses only CONSUMING partitions.
 * This feature is in <strong>Beta</strong>.
 */
public class MultiStageReplicaGroupSelector extends BaseInstanceSelector {
  private static final Logger LOGGER = LoggerFactory.getLogger(MultiStageReplicaGroupSelector.class);

  // Most recently fetched InstancePartitions. It is re-fetched by init() and by the change listeners below and read
  // by select(); volatile so that select() observes the latest reference written by those methods.
  private volatile InstancePartitions _instancePartitions;

  /**
   * Creates the selector. All arguments are forwarded unchanged to {@link BaseInstanceSelector}.
   */
  public MultiStageReplicaGroupSelector(String tableNameWithType, ZkHelixPropertyStore<ZNRecord> propertyStore,
      BrokerMetrics brokerMetrics, @Nullable AdaptiveServerSelector adaptiveServerSelector, Clock clock,
      boolean useFixedReplica, long newSegmentExpirationTimeInSeconds) {
    super(tableNameWithType, propertyStore, brokerMetrics, adaptiveServerSelector, clock, useFixedReplica,
        newSegmentExpirationTimeInSeconds);
  }

  /**
   * Delegates to the base class and then loads the InstancePartitions for the table.
   */
  @Override
  public void init(Set<String> enabledInstances, IdealState idealState, ExternalView externalView,
      Set<String> onlineSegments) {
    super.init(enabledInstances, idealState, externalView, onlineSegments);
    _instancePartitions = getInstancePartitions();
  }

  /**
   * Delegates to the base class and then re-fetches the InstancePartitions for the table.
   */
  @Override
  public void onInstancesChange(Set<String> enabledInstances, List<String> changedInstances) {
    super.onInstancesChange(enabledInstances, changedInstances);
    _instancePartitions = getInstancePartitions();
  }

  /**
   * Delegates to the base class and then re-fetches the InstancePartitions for the table.
   */
  @Override
  public void onAssignmentChange(IdealState idealState, ExternalView externalView, Set<String> onlineSegments) {
    super.onAssignmentChange(idealState, externalView, onlineSegments);
    _instancePartitions = getInstancePartitions();
  }

  /**
   * Picks a starting replica-group (0 for fixed-replica routing, otherwise derived from the requestId) and tries each
   * replica-group in round-robin order from there until one can serve all the given segments.
   *
   * @return pair of (segment to selected instance, optional segment to instance) maps for the first replica-group
   *         that can serve every segment
   * @throws RuntimeException if no replica-group can serve all the segments (including when the InstancePartitions
   *         has no replica-groups); the failure of the last attempted replica-group, if any, is attached as the cause
   */
  @Override
  Pair<Map<String, String>, Map<String, String>> select(List<String> segments, int requestId,
      SegmentStates segmentStates, Map<String, String> queryOptions) {
    // Read the volatile field once so that the whole selection works off a single InstancePartitions reference even
    // if one of the event-listeners above replaces it concurrently.
    InstancePartitions instancePartitions = _instancePartitions;
    int numReplicaGroups = instancePartitions.getNumReplicaGroups();
    Exception lastFailure = null;
    // Guard against zero replica-groups so that we fail with the descriptive exception below instead of an
    // ArithmeticException from the modulo.
    if (numReplicaGroups > 0) {
      int replicaGroupSelected;
      if (isUseFixedReplica(queryOptions)) {
        // When using sticky routing, we want to iterate over the instancePartitions in order to ensure deterministic
        // selection of replica group across queries i.e. same instance replica group id is picked each time.
        // Since the instances within a selected replica group are iterated in order, the assignment within a selected
        // replica group is guaranteed to be deterministic.
        // Note: This can cause major hotspots in the cluster.
        replicaGroupSelected = 0;
      } else {
        // floorMod (unlike %) never returns a negative value for a positive divisor, so a negative requestId still
        // maps to a valid replica-group id and every replica-group gets visited by the loop below.
        replicaGroupSelected = Math.floorMod(requestId, numReplicaGroups);
      }
      for (int iteration = 0; iteration < numReplicaGroups; iteration++) {
        int replicaGroup = (replicaGroupSelected + iteration) % numReplicaGroups;
        try {
          return tryAssigning(segments, segmentStates, instancePartitions, replicaGroup);
        } catch (Exception e) {
          LOGGER.warn("Unable to select replica-group {} for table: {}", replicaGroup, _tableNameWithType, e);
          // Remember the failure so that it can be reported as the cause if every replica-group fails.
          lastFailure = e;
        }
      }
    }
    throw new RuntimeException(
        String.format("Unable to find any replica-group to serve table: %s", _tableNameWithType), lastFailure);
  }

  /**
   * Returns a pair of maps from the segmentName to the corresponding server in the given replica-group: the first map
   * holds the segments whose chosen candidate is online, the second one holds the segments whose chosen candidate is
   * not online (optional segments). If a segment has no candidate instance within the given replica-group, we throw
   * an exception.
   *
   * @throws IllegalStateException if there are no candidates for a segment
   * @throws RuntimeException if none of the candidates of a segment belongs to the given replica-group
   */
  private Pair<Map<String, String>, Map<String, String>> tryAssigning(List<String> segments,
      SegmentStates segmentStates, InstancePartitions instancePartitions, int replicaId) {
    // Collect the instances of the given replica-group across all the partitions for constant time lookups.
    Set<String> instanceLookUpSet = new HashSet<>();
    int numPartitions = instancePartitions.getNumPartitions();
    for (int partition = 0; partition < numPartitions; partition++) {
      List<String> instances = instancePartitions.getInstances(partition, replicaId);
      instanceLookUpSet.addAll(instances);
    }
    Map<String, String> segmentToSelectedInstanceMap = new HashMap<>();
    Map<String, String> optionalSegmentToInstanceMap = new HashMap<>();
    for (String segment : segments) {
      List<SegmentInstanceCandidate> candidates = segmentStates.getCandidates(segment);
      // If candidates are null, we will throw an exception and the caller will log a warning.
      Preconditions.checkState(candidates != null, "Failed to find servers for segment: %s", segment);
      boolean found = false;
      // candidates array is always sorted
      for (SegmentInstanceCandidate candidate : candidates) {
        String instance = candidate.getInstance();
        if (instanceLookUpSet.contains(instance)) {
          found = true;
          // This can only be offline when it is a new segment. And such segment is marked as optional segment so that
          // broker or server can skip it upon any issue to process it.
          if (candidate.isOnline()) {
            segmentToSelectedInstanceMap.put(segment, instance);
          } else {
            optionalSegmentToInstanceMap.put(segment, instance);
          }
          // Only the first candidate within the replica-group is used.
          break;
        }
      }
      if (!found) {
        throw new RuntimeException(String.format("Unable to find an enabled instance for segment: %s", segment));
      }
    }
    return Pair.of(segmentToSelectedInstanceMap, optionalSegmentToInstanceMap);
  }

  /**
   * Fetches the InstancePartitions of the table from the property store. OFFLINE tables use the InstancePartitions
   * named after the table type; all other tables use the CONSUMING InstancePartitions.
   *
   * @throws NullPointerException if the table type cannot be derived from the table name, or if the fetched
   *         InstancePartitions is null
   */
  @VisibleForTesting
  protected InstancePartitions getInstancePartitions() {
    // TODO: Evaluate whether we need to provide support for COMPLETE partitions.
    TableType tableType = TableNameBuilder.getTableTypeFromTableName(_tableNameWithType);
    Preconditions.checkNotNull(tableType);
    InstancePartitions instancePartitions;
    if (tableType.equals(TableType.OFFLINE)) {
      instancePartitions = InstancePartitionsUtils.fetchInstancePartitions(_propertyStore,
          InstancePartitionsUtils.getInstancePartitionsName(_tableNameWithType, tableType.name()));
    } else {
      instancePartitions = InstancePartitionsUtils.fetchInstancePartitions(_propertyStore,
          InstancePartitionsUtils.getInstancePartitionsName(_tableNameWithType,
              InstancePartitionsType.CONSUMING.name()));
    }
    Preconditions.checkNotNull(instancePartitions);
    return instancePartitions;
  }
}
